import curses
import os
import pickle
import shutil
import sys

import scipy.io.wavfile as wav
import sounddevice as sd
from pydub import AudioSegment

# Start curses as soon as the module is imported: the Win class body below
# reads curses.LINES / curses.COLS, which are only available after initscr().
curses.initscr()
# Do not echo typed keys onto the screen.
curses.noecho()

class Win:
    """A curses sub-window together with the geometry it was created with.

    Every instance is carved out of ``origin_win``, one full-screen window
    shared by the whole class. Most methods forward their arguments unchanged
    to the wrapped curses window; ``build`` is a hook that ``draw`` invokes
    between erasing and refreshing.
    """

    # Shared full-screen parent window. It is created while the class body
    # executes, so curses has to be initialised before this class is defined.
    origin_win = curses.newwin(curses.LINES, curses.COLS, 0, 0)

    def __init__(self, hei, wid, y, x):
        """Create a ``hei`` x ``wid`` sub-window with top-left corner ``(y, x)``."""
        self.hei, self.wid = hei, wid
        self.y, self.x = y, x
        self.win = self.origin_win.subwin(hei, wid, y, x)

    def build(self):
        """Hook for filling the window with content; does nothing by default."""

    def draw(self):
        """Erase the window, let ``build`` populate it, then refresh it."""
        self.erase()
        self.build()
        self.refresh()

    def refresh(self, *args, **kwargs):
        """Refresh the wrapped window (the result is not returned)."""
        self.win.refresh(*args, **kwargs)

    def erase(self, *args, **kwargs):
        """Forward to ``erase`` of the wrapped window."""
        return self.win.erase(*args, **kwargs)

    def box(self, *args, **kwargs):
        """Forward to ``box`` of the wrapped window."""
        return self.win.box(*args, **kwargs)

    def addstr(self, *args, **kwargs):
        """Forward to ``addstr`` of the wrapped window."""
        return self.win.addstr(*args, **kwargs)

    def attron(self, *args, **kwargs):
        """Forward to ``attron`` of the wrapped window."""
        return self.win.attron(*args, **kwargs)

    def getch(self, *args, **kwargs):
        """Forward to ``getch`` of the wrapped window and return its result."""
        return self.win.getch(*args, **kwargs)

    def getkey(self, *args, **kwargs):
        """Forward to ``getkey`` of the wrapped window and return its result."""
        return self.win.getkey(*args, **kwargs)

# Module-wide full-screen window; CleanHelper draws on it and reads keys from it.
win = Win(curses.LINES, curses.COLS, 0, 0)


class File:
    """Facts about one audio file, captured once at construction time.

    Attributes:
        path: The path exactly as passed in.
        name: Final component of ``path``.
        dir: Everything in ``path`` before the final component.
        dB: ``dBFS`` value reported by pydub for the decoded audio.
        size: Size of the file in bytes.
    """

    def __init__(self, file_path):
        """Read the name, loudness and size of the file at ``file_path``."""
        folder, base = os.path.split(file_path)
        self.path = file_path
        self.name = base
        self.dir = folder

        audio = AudioSegment.from_file(file_path)
        self.dB = audio.dBFS
        self.size = os.path.getsize(file_path)
    

class FileList:
    """An ordered list of files with a cursor pointing at the current one.

    Attributes:
        files: The underlying list; items must expose a ``path`` attribute
            for ``refresh`` to work.
        idx: Index of the current item. It is kept at ``0`` when the list is
            empty so that it never becomes negative.
        sort_type: Label describing how ``files`` is ordered.
    """

    def __init__(self, files, sort_type):
        """Wrap ``files`` (used as-is, not copied) with the cursor at the start."""
        self.files = files
        self.idx = 0
        self.sort_type = sort_type

    def _clamp(self, idx):
        """Return ``idx`` limited to a valid position (``0`` for an empty list)."""
        return max(0, min(self.length - 1, idx))

    def refresh(self):
        """Drop entries whose file no longer exists on disk.

        Each removal at or before the cursor moves the cursor one step back
        (never below ``0``), so the cursor keeps following its file while
        earlier entries disappear.
        """
        i = 0
        while i < len(self.files):
            if os.path.exists(self.files[i].path):
                i += 1
                continue
            # Do not advance ``i``: the next entry has slid into this slot.
            self.files.pop(i)
            if i <= self.idx:
                self.idx = max(0, self.idx - 1)

    def next(self):
        """Move the cursor one item forward, stopping at the last item."""
        self.idx = self._clamp(self.idx + 1)

    def last(self):
        """Move the cursor one item back, stopping at the first item."""
        self.idx = self._clamp(self.idx - 1)

    def pop(self):
        """Remove and return the current item.

        The cursor then points at the item that followed it, or at the new
        last item when the removed one was at the end.

        Raises:
            IndexError: If the list is empty.
        """
        cur_file = self.files.pop(self.idx)
        # Clamping to 0 (instead of -1) keeps ``idx`` sane once the list is empty.
        self.idx = self._clamp(self.idx)
        return cur_file

    @property
    def cur(self):
        """The item under the cursor; raises ``IndexError`` on an empty list."""
        return self.files[self.idx]

    @property
    def length(self):
        """Number of items currently in the list."""
        return len(self.files)
        


class CleanHelper:
    """Interactive helper for auditioning and trashing ``.wav`` files.

    The ``.wav`` files found in ``target_dir`` are kept in four differently
    ordered ``FileList`` views (listing order, loudness, size, name) that share
    the same ``File`` objects. Key presses read from the module-level ``win``
    switch views, move the cursor, play the selected file, or move it into
    ``<parent of target_dir>/trash/<name of target_dir>``.
    """

    def __init__(self, target_dir):
        """Scan ``target_dir`` for ``.wav`` files and prepare the trash folder.

        Args:
            target_dir: Directory whose ``.wav`` files should be reviewed.
        """
        self.target_dir = target_dir
        # Split the absolute path so that inputs such as "." or "dir/" still
        # yield the real parent directory and a non-empty base name.
        parent_dir, base_name = os.path.split(os.path.abspath(target_dir))

        self.trash_dir = os.path.join(parent_dir, 'trash', base_name)
        os.makedirs(self.trash_dir, exist_ok=True)

        files = [
            File(os.path.join(target_dir, f))
            for f in os.listdir(target_dir)
            if f.endswith('.wav')
        ]

        self.default_list = FileList(files, '默认排序')

        # Each sort starts from the previous ordering; because sorting is
        # stable, ties keep the order of the list they were derived from.
        files = sorted(files, key=lambda f: f.dB)
        self.dB_list = FileList(files, '音量大小')

        files = sorted(files, key=lambda f: f.size)
        self.size_list = FileList(files, '文件大小')

        files = sorted(files, key=lambda f: f.name)
        self.name_list = FileList(files, '文件名')

        self.cur_list = self.default_list
        # Last key code read by ``loop``.
        self.key = None

    def _all_lists(self):
        """Return every sort view kept by this helper."""
        return (self.default_list, self.dB_list, self.size_list, self.name_list)

    def _drop_from_other_lists(self, removed):
        """Remove ``removed`` from every view except the current one.

        The cursor of each affected view keeps pointing at the same file where
        possible and is clamped into range otherwise.
        """
        for file_list in self._all_lists():
            if file_list is self.cur_list:
                continue
            try:
                pos = file_list.files.index(removed)
            except ValueError:
                continue
            del file_list.files[pos]
            if pos < file_list.idx:
                file_list.idx -= 1
            file_list.idx = max(0, min(file_list.idx, len(file_list.files) - 1))

    def _switch_to(self, file_list):
        """Make ``file_list`` the active view and redraw the screen."""
        self.cur_list = file_list
        self.draw()

    def play(self):
        """Start playing the current file if there is one and it still exists."""
        if self.cur_list.length == 0:
            return
        cur_file_path = self.cur_list.cur.path

        if os.path.exists(cur_file_path):
            fs, data = wav.read(cur_file_path)
            sd.play(data, fs)

    def stop(self):
        """Stop any playback in progress."""
        sd.stop()

    def exit(self):
        """Undo the curses terminal settings and leave curses mode."""
        curses.echo()
        curses.nocbreak()
        curses.endwin()

    def delete(self):
        """Move the current file to the trash folder and drop it from all views.

        Does nothing apart from redrawing when the current view is empty.
        """
        self.stop()
        if self.cur_list.length == 0:
            self.draw()
            return

        cur_file = self.cur_list.cur
        source_path = os.path.join(self.target_dir, cur_file.name)
        save_path = os.path.join(self.trash_dir, cur_file.name)
        if os.path.exists(source_path):
            # shutil.move avoids building a shell command, so file names with
            # quotes, "$" or backticks are handled correctly.
            shutil.move(source_path, save_path)

        # Only forget the file once the move has gone through.
        self.cur_list.pop()
        self._drop_from_other_lists(cur_file)

        self.draw()
        self.play()

    def last(self):
        """Select and play the previous file of the current view."""
        self.stop()
        self.cur_list.last()
        self.draw()
        self.play()

    def next(self):
        """Select and play the next file of the current view."""
        self.stop()
        self.cur_list.next()
        self.draw()
        self.play()

    def draw(self):
        """Redraw the position counter, file details and key help."""
        win.erase()
        h = win.hei // 2
        w = win.wid // 2

        if self.cur_list.length == 0:
            # Nothing is selected, so there are no file details to show.
            text = '0 / 0'
            win.addstr(h - 2, w - len(text) // 2, text)
            win.addstr(h + 1, 3, '当前列表中没有文件')
        else:
            text = f'{self.cur_list.idx+1} / {self.cur_list.length}'
            win.addstr(h - 2, w - len(text) // 2, text)

            cur = self.cur_list.cur
            win.addstr(h + 1, 3, f'当前文件名：{cur.name}')
            size = round(cur.size / 1024, 2)
            win.addstr(h + 3, 3, f'当前文件大小：{size}kb')
            dB = round(cur.dB, 2)
            win.addstr(h + 5, 3, f'当前文件音量：{dB}dB')
        win.addstr(h + 7, 3, f'当前排序：{self.cur_list.sort_type}')

        win.addstr(win.hei - 4, 5, '← 上一首       ↓ 删除         → 下一首')
        win.addstr(win.hei - 3, 5, 'd 默认排序     v 按音量排序    s 按大小排序    n 按名字排序')
        win.addstr(win.hei - 2, 5, 'q 退出         r 更新列表      < 第一个文件    > 最后一个文件')

    def getch(self):
        """Wait for a key press on the main window and return its code."""
        return win.getch()

    def filter(self):
        """Placeholder; filtering is not implemented."""
        pass

    def sort_default(self):
        """Show the files in directory-listing order."""
        self._switch_to(self.default_list)

    def sort_by_dB(self):
        """Show the files ordered by loudness."""
        self._switch_to(self.dB_list)

    def sort_by_size(self):
        """Show the files ordered by size."""
        self._switch_to(self.size_list)

    def sort_by_name(self):
        """Show the files ordered by name."""
        self._switch_to(self.name_list)

    def go_end(self):
        """Jump to the last file of the current view and play it."""
        self.stop()
        # max() keeps the cursor at 0 instead of -1 when the view is empty.
        self.cur_list.idx = max(0, self.cur_list.length - 1)
        self.draw()
        self.play()

    def go_sta(self):
        """Jump to the first file of the current view and play it."""
        self.stop()
        self.cur_list.idx = 0
        self.draw()
        self.play()

    def refresh(self):
        """Drop files that vanished from disk from every view and redraw."""
        for file_list in self._all_lists():
            file_list.refresh()
        self.draw()

    def loop(self):
        """Run the key-driven event loop until ``q`` is pressed.

        The terminal is restored through ``exit`` when the loop ends, including
        when an exception escapes from a handler.
        """
        # With keypad mode on, arrow keys arrive as single curses.KEY_* codes.
        # Matching the last byte of their escape sequence (66/67/68) would
        # also fire on the plain letters "B", "C" and "D".
        win.win.keypad(True)
        handlers = {
            ord('d'): self.sort_default,
            ord('v'): self.sort_by_dB,
            ord('s'): self.sort_by_size,
            ord('n'): self.sort_by_name,
            ord('r'): self.refresh,
            ord('<'): self.go_sta,
            ord('>'): self.go_end,
            curses.KEY_DOWN: self.delete,
            curses.KEY_LEFT: self.last,
            curses.KEY_RIGHT: self.next,
        }
        quit_key = ord('q')

        try:
            self.draw()
            while True:
                key = self.getch()
                self.key = key
                if key == quit_key:
                    break
                handler = handlers.get(key)
                if handler is not None:
                    handler()
        finally:
            self.exit()

def dump(state_path, cleaner):
    """Pickle ``cleaner`` into the file at ``state_path``.

    The file is created or overwritten and written with the highest pickle
    protocol available.
    """
    with open(state_path, 'wb') as out:
        pickle.Pickler(out, pickle.HIGHEST_PROTOCOL).dump(cleaner)

def load(state_path):
    """Return the object pickled in the file at ``state_path``.

    NOTE: unpickling can execute arbitrary code, so only state files that
    come from a trusted source should be passed in.
    """
    with open(state_path, 'rb') as src:
        return pickle.Unpickler(src).load()


def main():
    """Run the cleaner on the directory named by the last command-line argument.

    A previous session is resumed from ``state.pkl`` inside that directory when
    the file exists; otherwise the directory is scanned afresh. The session is
    saved back to ``state.pkl`` after the event loop ends normally.

    Raises:
        SystemExit: If no directory argument is given or it is not a directory.
    """
    try:
        if len(sys.argv) < 2:
            # Without this check sys.argv[-1] would be the script path itself.
            raise SystemExit(f'usage: {sys.argv[0]} TARGET_DIR')

        raw_dir = sys.argv[-1]
        # Drop trailing slashes, but keep a bare "/" (or an empty string) as-is.
        target_dir = raw_dir.rstrip('/') or raw_dir
        if not os.path.isdir(target_dir):
            raise SystemExit(f'not a directory: {target_dir}')

        state_path = os.path.join(target_dir, 'state.pkl')

        if os.path.exists(state_path):
            # NOTE: the state file is unpickled, which can run arbitrary code;
            # only point this tool at directories whose state.pkl is trusted.
            cleaner = load(state_path)
        else:
            cleaner = CleanHelper(target_dir)
        cleaner.loop()
        dump(state_path, cleaner)
    finally:
        # Curses is started at import time, so make sure the terminal is
        # handed back in a usable state on every exit path.
        if not curses.isendwin():
            curses.endwin()

# Start the cleaner only when this file is executed as a script.
if __name__ == '__main__':
    main()
